{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1faf67e-758c-4344-9a68-72cb52c87c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import paddle\n",
    "import plsc\n",
    "from plsc.models import vision_transformer\n",
    "from plsc.models.utils import tome\n",
    "from plsc.data import preprocess as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d52afc-8ad6-4352-940b-81ec2623838c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the pretrained weights and the demo image; files already on disk are kept.\n",
    "ASSET_URLS = {\n",
    "    'models/imagenet2012-ViT-B_16-224.pdparams':\n",
    "        'https://plsc.bj.bcebos.com/models/vit/v2.4/imagenet2012-ViT-B_16-224.pdparams',\n",
    "    'images/husky.png':\n",
    "        'https://plsc.bj.bcebos.com/dataset/test_images/husky.png',\n",
    "}\n",
    "\n",
    "for local_path, url in ASSET_URLS.items():\n",
    "    if os.path.exists(local_path):\n",
    "        continue\n",
    "    os.makedirs(os.path.dirname(local_path), exist_ok=True)\n",
    "    # Fetch under a temporary name and rename on success, so an interrupted\n",
    "    # download never leaves a truncated file that the check above would accept.\n",
    "    partial_path = local_path + '.part'\n",
    "    urllib.request.urlretrieve(url, partial_path)\n",
    "    os.replace(partial_path, local_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8896ebf0-7067-4caa-8ffb-d44ff20193d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the ViT_base_patch16_224 network, then restore the pretrained checkpoint.\n",
    "# NOTE(review): the prefix omits the '.pdparams' extension of the file on disk;\n",
    "# presumably load_pretrained appends it - confirm against plsc.\n",
    "WEIGHTS_PREFIX = 'models/imagenet2012-ViT-B_16-224'\n",
    "\n",
    "model = vision_transformer.ViT_base_patch16_224()\n",
    "model.load_pretrained(WEIGHTS_PREFIX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98a09ba5-8731-45d6-824f-1d084048c7fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing pipeline: bicubic resize to 256 with the PIL backend, 224 center\n",
    "# crop, multiply by 1/255 and normalize every channel with mean/std 0.5, and finally\n",
    "# ToCHWImage (presumably HWC -> CHW, matching order='hwc' above - confirm).\n",
    "transform_val = transforms.Compose([\n",
    "    transforms.Resize(\n",
    "        size=256, interpolation=\"bicubic\", backend=\"pil\"),\n",
    "    transforms.CenterCrop(size=224),\n",
    "    transforms.NormalizeImage(\n",
    "        scale=1.0 / 255.0,\n",
    "        mean=[0.5, 0.5, 0.5],\n",
    "        std=[0.5, 0.5, 0.5],\n",
    "        order='hwc'),\n",
    "    transforms.ToCHWImage()\n",
    "])\n",
    "\n",
    "# Force three channels: a PNG may be RGBA, palette-based or grayscale, which would\n",
    "# not match the 3-value mean/std above. This is a no-op for an image already in RGB.\n",
    "image_rgb = Image.open(\"images/husky.png\").convert(\"RGB\")\n",
    "# Add a leading batch axis before handing the array to paddle.\n",
    "image_batch = transform_val(image_rgb)[None, ...]\n",
    "img = paddle.to_tensor(image_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e621972-6f2d-4a4e-bc92-bfd716a10adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the tome patch to the model and set its reduction parameter r.\n",
    "# NOTE(review): r is presumably the number of tokens merged per layer - confirm\n",
    "# against plsc.models.utils.tome.\n",
    "tome.apply_patch(model)\n",
    "model.r = 8\n",
    "\n",
    "# Inference only: leave training mode and skip autograd bookkeeping.\n",
    "model.eval()\n",
    "with paddle.no_grad():\n",
    "    logits = model(img)\n",
    "    preds = paddle.nn.functional.softmax(logits, axis=-1)\n",
    "\n",
    "# Indices of the five highest-probability classes, most likely first.\n",
    "out = preds.argsort(descending=True)[..., :5]\n",
    "print(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
